import torch
torch.set_default_dtype(torch.float64)
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time, argparse
from pprint import pprint
import json
from utils import dict_agg, set_seed
import default_args
from helper_new_portfolio import NNPrimalSolver, NNDualSolver, load_portfolio_data, load_portfolio_dyn_data
from dataset import Dataset as D
import pandas as pd
from pathlib import Path
from NSDE_training import NeuralSDE
import torchsde

# Absolute directory of this script; data loading and saved results are resolved relative to it.
CURRENT_PATH = Path(__file__).absolute().parent

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE", DEVICE, flush=True)

def main():
    """Parse CLI arguments, train the supervised baseline and record its results.

    After training, one row of metrics taken from the best outer iterations
    (as returned by ``train_net``) is appended to
    ``supervised_static_results.csv`` in the current working directory, without
    a header row. When ``--save`` is true, a checkpoint holding the network
    weights, the resolved arguments and the final test metrics is written under
    ``results/<probtype>/`` next to this script.
    """

    def str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` calls ``bool(str)``, which is True for every
        non-empty string, so e.g. ``--save False`` used to *enable* saving.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser(description='Baseline Supervised')
    parser.add_argument('--seed', type=int, default=1001, help='random seed')
    parser.add_argument('--probtype', type=str, default='portfolio', choices=['portfolio'], help='problem type')

    # QP cases specific parameters
    parser.add_argument('--nvar', type=int, help='number of decision variables')
    parser.add_argument('--nineq', type=int, help='number of inequality constraints')
    parser.add_argument('--neq', type=int, help='number of equality constraints')
    parser.add_argument('--nex', type=int, help='total number of data instances')
    parser.add_argument('--use_sigmoid', type=str2bool, help='whether to apply a sigmoid to the last layer')
    parser.add_argument('--useLSTM', type=str2bool, default=False, help='whether to use the LSTM model')
    parser.add_argument('--LSTMmodelindex', type=int, default=29, help='LSTM model index')

    # SL baseline parameters
    parser.add_argument('--maxouteriter', type=int, help='maximum outer iterations')
    parser.add_argument('--losstype', type=str, default='msep', choices=['mae', 'mse', 'maep', 'msep', 'ld'],
                        help='loss type: mae/mse, penalized maep/msep, or ld (penalty with multiplier updates)')
    parser.add_argument('--ldupdatefreq', type=int, help='LD penalty coefficient update epoch frequency')
    parser.add_argument('--ldstepsize', type=float, help='LD multiplier update step size')
    parser.add_argument('--use_feasibility_restoration', type=str2bool, help='whether to apply the feasibility restoration procedure or not')

    # Related to training & neural nets
    parser.add_argument('--batchsize', type=int, help='training batch size')
    parser.add_argument('--epochs', type=int, help='number of epochs')
    parser.add_argument('--lr', type=float, help='neural network learning rate')
    parser.add_argument('--hiddensize', type=int, default=500, help='hidden layer size for neural network')
    # NOTE: the former default of 1.2 for --hiddenfrac was removed on purpose.
    parser.add_argument('--hiddenfrac', type=float, help='hidden layer node fraction (only used for ACOPF)')

    parser.add_argument('--nlayer', type=int, help='the number of layers')
    parser.add_argument('--lamg', type=float, help='penalty coefficient for inequality constraints')
    parser.add_argument('--lamh', type=float, help='penalty coefficient for equality constraints')
    parser.add_argument('--save', type=str2bool, default=False, help='whether to save statistics')

    # NOTE(review): --objscaler is declared as a bool but described as a scaling
    # factor; type=bool maps any non-empty string to True. Left unchanged because
    # the intended type is unclear -- confirm.
    parser.add_argument('--objscaler', type=bool, default=None, help='objective scaling factor')
    parser.add_argument('--index', type=int, help='index to keep track of different runs')

    args = vars(parser.parse_args())  # to dictionary
    # Fill every option the user did not pass with the problem/loss-specific default.
    args_default = default_args.baseline_supervised_default_args(args['probtype'], args['losstype'])
    for k, v in args_default.items():
        args[k] = v if args[k] is None else args[k]
    pprint(args)

    set_seed(args['seed'], DEVICE)

    if 'portfolio' in args['probtype']:
        data, args = load_portfolio_dyn_data(args, CURRENT_PATH, DEVICE)
    else:
        raise NotImplementedError
    print("Loading Data Done Successfully:", str(data))

    tstart = time.time()
    out, net, best_results, best_results_batch = train_net(data, args)

    # One CSV row per run. The file is written without a header, so the key
    # order below is the column order. Kept in its own variable so that `data`
    # still refers to the dataset when building the checkpoint name below.
    record = {
        'Seed': [args['seed']],
        'Losstype': [args['losstype']],
        'Max eq. viol': [best_results[0]],
        'Mean eq. viol': [best_results[1]],
        'Max ineq. viol.': [best_results[2]],
        'Mean ineq. viol.': [best_results[3]],
        'Mean opt.gap': [best_results[4]],
        'Mean opt.gap batch': [best_results[5]],

        'Max eq. viol batch': [best_results_batch[0]],
        'Mean eq. viol batch': [best_results_batch[1]],
        'Max ineq. viol. batch': [best_results_batch[2]],
        'Mean ineq. viol. batch': [best_results_batch[3]],
        'Mean opt.gap per batch': [best_results_batch[4]],
        'Mean opt.gap batch per batch': [best_results_batch[5]],
    }
    pd.DataFrame(record).to_csv('supervised_static_results.csv', mode='a', header=False, index=False)
    print("Record saved.")

    print("Train is done, Elapsed Time %.2fs" % (time.time() - tstart), flush=True)
    if args['save']:
        save_dir = CURRENT_PATH / "results" / args['probtype']
        save_dir.mkdir(exist_ok=True, parents=True)
        save_name = "Supervised_loss%s_%s_s%d.chpt" % (args['losstype'], str(data), args['seed'])
        save_fn = save_dir / save_name
        save_dict = {
            'net': net.to('cpu').state_dict(),
            'args': args,
            'out': out,
        }
        torch.save(save_dict, save_fn)

    return None


def train_net(data, args):
    """Train the primal proxy network and track the best test-set results.

    The network input is not ``data.trainX`` but the initial asset prices stored
    in ``portfolio_data/init_asset_prices_training_<int(.8*nex)>.npy`` (transposed
    so rows are samples). ``total_loss`` compares the network output with
    ``data.trainY``. The inner loop of ``args['epochs']`` epochs is repeated
    ``args['maxouteriter']`` times. After each repetition the net is evaluated
    on the train and test sets and the best results so far are updated.

    Args:
        data: dataset object exposing trainX/trainY, validX/validY, testX/testY
            and the residual/objective helpers used by ``eval_net``/``total_loss``.
        args: resolved configuration dict.

    Returns:
        ``(out, net, best_results, best_results_batch)``.
        ``out`` holds the test metrics of the last outer iteration.
        Each ``best_*`` list is ``[eq_max, eq_mean, ineq_max, ineq_mean,
        mean_opt_gap_pct, batch_opt_gap_pct]``, taken from the outer iteration
        with the lowest mean per-sample opt. gap (``best_results``) or with the
        lowest batch-level opt. gap (``best_results_batch``). The lists stay
        NaN-filled if no iteration ever beat the initial threshold (e.g. NaN
        gaps or ``maxouteriter == 0``).
    """
    train_dataset = BaselineDataSet(data.trainX, data.trainY)
    valid_dataset = BaselineDataSet(data.validX, data.validY)
    test_dataset = BaselineDataSet(data.testX, data.testY)
    batch_size = args['batchsize']

    # Shuffled (X, Y) batches; only consumed by the LD multiplier update.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Shuffled sample *indices*. The same index selects the targets from
    # train_dataset and the matching rows of the initial-price features, so the
    # two stay aligned. Slicing features by batch number would not once shuffled.
    train_index_loader = DataLoader(torch.arange(len(train_dataset)), batch_size=batch_size, shuffle=True)
    # In-order batches for train-set evaluation: eval_net slices the initial
    # prices by batch number j, which is only correct without shuffling.
    train_eval_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

    if 'portfolio' in args['probtype']:
        net = NNPrimalSolver(data, args).to(DEVICE)
    else:
        raise NotImplementedError
    net.train()
    optimizer = optim.Adam(net.parameters(), lr=args['lr'])

    lamh = args['lamh'] * torch.ones(args['neq'], dtype=torch.get_default_dtype()).to(DEVICE)
    lamg = args['lamg'] * torch.ones(args['nineq'], dtype=torch.get_default_dtype()).to(DEVICE)

    # Placeholders: no feature generator / parameter regressor is used here.
    feature_Generator, parameter_regressor_Net = 0, 0
    # Rows = training samples; moved to the network's device once, up front.
    init_X_train = torch.from_numpy(
        np.load(f'portfolio_data/init_asset_prices_training_{int(.8*args["nex"])}.npy').T
    ).double().to(DEVICE)

    global min_opt_gap, min_opt_gap_batch
    min_opt_gap, min_opt_gap_batch = 100000, 100000
    global n_epochs
    n_epochs = 0

    out = {}
    best_results = [float('nan')] * 6
    best_results_batch = [float('nan')] * 6

    for _ in range(args['maxouteriter']):
        n_epochs += 1
        epoch_stats = {}
        for epoch in range(args['epochs']):
            t0 = time.time()
            epoch_stats = {}
            train_loss_ = 0.
            net.train()
            for idx in train_index_loader:
                Xtrain, Ytrain = train_dataset[idx]
                optimizer.zero_grad()
                Yhat_train = net(init_X_train[idx])
                train_loss = total_loss(data, Xtrain, Ytrain, Yhat_train, lamg, lamh, args).mean()
                train_loss.backward()
                optimizer.step()
                train_loss_ += train_loss.item()
            train_loss_ /= len(train_index_loader)
            t1 = time.time()
            net.eval()

            for Xvalid, Yvalid in valid_loader:
                epoch_stats = eval_net(data, Xvalid, Yvalid, net, feature_Generator, parameter_regressor_Net, lamg, lamh, args, 'valid', epoch_stats, -1)
            if epoch % 10 == 0 and epoch > 0:
                print("P Epoch:%05d | loss:%.4f | time:%.4fs" % (
                    epoch, train_loss_, t1 - t0
                ), flush=True)
                print("        valid | loss:%.4f | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
                    np.mean(epoch_stats['valid_primal_loss']), np.mean(epoch_stats['valid_eval']),
                    np.mean(epoch_stats['valid_ineq_max']), np.mean(epoch_stats['valid_ineq_mean']),
                    np.mean(epoch_stats['valid_eq_max']), np.mean(epoch_stats['valid_eq_mean'])
                ), flush=True)
            if args['losstype'] == 'ld' and epoch % args['ldupdatefreq'] == 0 and epoch > 0:
                lamg, lamh = update_lamda(train_loader, net, feature_Generator, data, lamg, lamh, args)

        for j, (Xtrain, Ytrain) in enumerate(train_eval_loader):
            epoch_stats = eval_net(data, Xtrain, Ytrain, net, feature_Generator, parameter_regressor_Net, lamg, lamh, args, 'train', epoch_stats, j)

        for X, Y in test_loader:
            epoch_stats = eval_net(data, X, Y, net, feature_Generator, parameter_regressor_Net, lamg, lamh, args, 'test', epoch_stats, -1)
            epoch_stats = eval_net(data, X, Y, net, feature_Generator, parameter_regressor_Net, lamg, lamh, args, 'test_gt', epoch_stats, -1)

        print("Quantile of opt. gap", np.quantile(epoch_stats['test_opt_gap'], [0.25, 0.5, 0.75, 1]))
        print("        train |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
            np.mean(epoch_stats['train_eval']),
            np.mean(epoch_stats['train_ineq_max']), np.mean(epoch_stats['train_ineq_mean']),
            np.mean(epoch_stats['train_eq_max']), np.mean(epoch_stats['train_eq_mean'])
        ), flush=True)
        print("         test |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f | optgap max:%.4f mean:%.4f " % (
            np.mean(epoch_stats['test_eval']),
            np.mean(epoch_stats['test_ineq_max']), np.mean(epoch_stats['test_ineq_mean']),
            np.mean(epoch_stats['test_eq_max']), np.mean(epoch_stats['test_eq_mean']),
            100 * np.max(epoch_stats['test_opt_gap']), 100 * np.mean(epoch_stats['test_opt_gap']),
        ))
        print("      test gt |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
            np.mean(epoch_stats['test_gt_eval']),
            np.mean(epoch_stats['test_gt_ineq_max']), np.mean(epoch_stats['test_gt_ineq_mean']),
            np.mean(epoch_stats['test_gt_eq_max']), np.mean(epoch_stats['test_gt_eq_mean'])
        ), flush=True)

        out = {
            'obj': np.mean(epoch_stats['test_eval']),
            'eq_max': np.mean(epoch_stats['test_eq_max']),
            'ineq_max': np.mean(epoch_stats['test_ineq_max']),
            'eq_mean': np.mean(epoch_stats['test_eq_mean']),
            'ineq_mean': np.mean(epoch_stats['test_ineq_mean']),
            'opt_gap_max': 100 * np.max(epoch_stats['test_opt_gap']),
            'opt_gap_mean': 100 * np.mean(epoch_stats['test_opt_gap']),
        }

        # Two notions of optimality gap (both in percent):
        #  - mean of the per-sample gaps, and
        #  - gap between the mean predicted objective and the mean ground-truth objective.
        mean_opt_gap = out['opt_gap_mean']
        batch_opt_gap = 100 * abs(epoch_stats['test_est_obj'] - epoch_stats['test_gt_obj']) / abs(epoch_stats['test_gt_obj'])
        record = [out['eq_max'], out['eq_mean'], out['ineq_max'], out['ineq_mean'], mean_opt_gap, batch_opt_gap]

        # Update the running minima; previously they were never updated, so the
        # "best" results were always those of the last outer iteration.
        if mean_opt_gap < min_opt_gap:
            min_opt_gap = mean_opt_gap
            best_results = list(record)
        if batch_opt_gap < min_opt_gap_batch:
            min_opt_gap_batch = batch_opt_gap
            best_results_batch = list(record)

    return out, net, best_results, best_results_batch


# Initial-asset-price arrays keyed by file path. eval_net runs once per
# validation pass and once per training batch, so re-reading the .npy files on
# every call is wasted I/O. The cached arrays are only read, never modified.
_INIT_PRICES_CACHE = {}


def _load_init_prices(path):
    """Return ``np.load(path)``, reading each file from disk at most once per process."""
    cached = _INIT_PRICES_CACHE.get(path)
    if cached is None:
        cached = _INIT_PRICES_CACHE[path] = np.load(path)
    return cached


def eval_net(data, X, Ygt, net, feature_Generator, parameter_regressor_Net, lamg, lamh, args, prefix, stats, j):
    """Evaluate ``net`` on one batch and accumulate metrics into ``stats``.

    Prediction source by ``prefix``:
      * 'test_gt' -- ``data.testY`` itself (ground-truth reference);
      * 'test'    -- net applied to the test initial asset prices. ``Ygt`` is
        replaced by ``data.testY`` and the MSE between those prices and ``X``
        is printed;
      * 'valid'   -- net applied to the validation initial asset prices;
      * otherwise (used as 'train') -- net applied to rows
        ``[j*batchsize, (j+1)*batchsize)`` of the training initial prices. This
        assumes the caller iterates the training set in order.
    ``X`` itself is only used for residuals, objective and loss, not as network
    input. ``feature_Generator`` and ``parameter_regressor_Net`` are unused.

    Records ``<prefix>_time`` (summed), ``_eval``, ``_primal_loss``,
    ``_ineq_max``, ``_ineq_mean``, ``_eq_max`` and ``_eq_mean``. For 'test' it
    additionally records ``_opt_gap`` and ``_est_obj``; for 'test_gt' it records
    ``_obj``. Gradient tracking is disabled for the duration of the call and
    restored afterwards, even if an exception is raised.

    Returns:
        The updated ``stats`` dict.
    """
    def key(name):
        return "{}_{}".format(prefix, name)

    with torch.no_grad():
        start_time = time.time()

        if prefix == 'test_gt':
            Y = data.testY
        elif prefix == 'test':
            Ygt = data.testY
            init_X_test = _load_init_prices(f'portfolio_data/init_asset_prices_test_{int(.1*args["nex"])}.npy')
            X_hat = torch.from_numpy(init_X_test.T).double().to(DEVICE)
            print("MSE test: ", torch.nn.MSELoss()(X_hat, X))
            Y = net(X_hat)
        elif prefix == 'valid':
            init_X_valid = _load_init_prices(f'portfolio_data/init_asset_prices_validation_{int(.1*args["nex"])}.npy')
            X_hat = torch.from_numpy(init_X_valid.T).double().to(DEVICE)
            Y = net(X_hat)
        else:
            init_X_train = _load_init_prices(f'portfolio_data/init_asset_prices_training_{int(.8*args["nex"])}.npy')
            batch = args['batchsize']
            X_hat = torch.from_numpy(init_X_train.T[j*batch:(j+1)*batch, :]).double().to(DEVICE)
            Y = net(X_hat)

        # Optional post-hoc repair of test predictions: clip to non-negative,
        # then renormalise rows to sum to 1 if equality residuals remain.
        if args['use_feasibility_restoration'] == True:
            if 'portfolio' == args['probtype'] and prefix == 'test':
                ineqval = data.ineq_dist(X, Y)
                if torch.count_nonzero(torch.abs(ineqval)).item() > 0:
                    Y = torch.clamp(Y, min=0)
                    ineqval = data.ineq_dist(X, Y)
                print("Inequality violations: ", torch.count_nonzero(torch.abs(ineqval)).item())
                eqval = data.eq_resid(X, Y).float()
                if torch.count_nonzero(torch.abs(eqval)).item() > 0:
                    Y = Y / Y.sum(dim=1, keepdim=True)
                    eqval = data.eq_resid(X, Y).float()
                print("Equality violations: ", torch.count_nonzero(torch.abs(eqval)).item())

        end_time = time.time()

        # Compute each residual once and reuse it for the max/mean statistics.
        ineq_dist = data.ineq_dist(X, Y)
        eq_abs = torch.abs(data.eq_resid(X, Y))

        dict_agg(stats, key('time'), end_time - start_time, op='sum')
        dict_agg(stats, key('eval'), data.obj_fn(X, Y).detach().cpu().numpy() * data.obj_scaler)
        dict_agg(stats, key('primal_loss'), total_loss(data, X, Ygt, Y, lamg, lamh, args).detach().cpu().numpy())
        dict_agg(stats, key('ineq_max'), torch.max(ineq_dist, dim=1)[0].detach().cpu().numpy())
        dict_agg(stats, key('ineq_mean'), torch.mean(ineq_dist, dim=1).detach().cpu().numpy())
        dict_agg(stats, key('eq_max'), torch.max(eq_abs, dim=1)[0].detach().cpu().numpy())
        dict_agg(stats, key('eq_mean'), torch.mean(eq_abs, dim=1).detach().cpu().numpy())

        if 'gt' not in prefix and 'test' in prefix:
            dict_agg(stats, key('opt_gap'), data.opt_gap(X, Y, Ygt).detach().cpu().numpy())
            dict_agg(stats, key('est_obj'), data.obj_fn(X, Y).detach().cpu().numpy().mean())
        elif 'gt' in prefix and 'test' in prefix:
            dict_agg(stats, key('obj'), data.obj_fn(X, Ygt).detach().cpu().numpy().mean())

    return stats


# Problem types whose targets are single tensors (ACOPF uses a dict of components instead).
_TENSOR_TARGET_PROBTYPES = frozenset([
    'predopt_nonconvexqp', 'convexqp', 'nonconvexqp', 'qcqp', 'bilinear',
    'portfolio_qcqp', 'predopt_portfolio_qcqp', 'predopt_bilinear',
    'knapsack', 'portfolio', 'predopt_portfolio', 'predopt_knapsack',
])


def total_loss(data, X, Ygt, Y, lamg, lamh, args):
    """Supervised per-sample loss, plus constraint penalties for penalized loss types.

    ``args['losstype']``:
      * 'mae', 'maep', 'ld' -- absolute errors;
      * 'mse', 'msep'       -- squared errors.
    'maep', 'msep' and 'ld' also add ``mean(lamh * eq_viol)`` and
    ``mean(lamg * ineq_viol)``. These are averaged over constraints and batch,
    so each penalty is a scalar added to every sample's loss. Inequality
    violations are the positive part of ``data.ineq_resid``. For ACOPF
    problems the loss averages the 'pg', 'qg', 'vm' and 'dva' components.

    Returns:
        Tensor of per-sample losses (target error averaged over dim 1).

    Raises:
        NotImplementedError: for an unsupported losstype or probtype.
    """
    losstype = args['losstype']
    probtype = args['probtype']
    ineq_val = data.ineq_resid(X, Y)
    eq_val = data.eq_resid(X, Y)

    if 'mae' in losstype or losstype == 'ld':
        def err(t):
            return t.abs()
    elif 'mse' in losstype:
        def err(t):
            return t.pow(2)
    else:
        raise NotImplementedError(f"unsupported losstype {losstype!r}")

    if probtype in _TENSOR_TARGET_PROBTYPES:
        loss = err(Ygt.to(DEVICE) - Y.to(DEVICE)).mean(dim=1)
    elif 'acopf' in probtype:
        component_losses = [
            err(Ygt[name].to(DEVICE) - Y[name].to(DEVICE)).mean(dim=1)
            for name in ('pg', 'qg', 'vm', 'dva')
        ]
        loss = 0.25 * (component_losses[0] + component_losses[1] + component_losses[2] + component_losses[3])
    else:
        raise NotImplementedError

    eq_viols = err(eq_val)
    # The clamped value is non-negative, so abs() leaves it unchanged for MAE.
    ineq_viols = err(torch.clamp(ineq_val, min=0.))

    if 'p' in losstype or losstype == 'ld':
        eq_term = (lamh * eq_viols).mean(dim=1).mean()
        ineq_term = (lamg * ineq_viols).mean(dim=1).mean()
        return loss + eq_term + ineq_term
    return loss

### VDVF
def featGen(x, feat_Gen_Net):
    """Run the feature-generator network on ``x`` and return its output on ``DEVICE``."""
    generated = feat_Gen_Net(x)
    return generated.to(DEVICE)

class BaselineDataSet(Dataset):
    """Map-style dataset pairing inputs ``X`` with targets ``Y`` by index.

    ``X``/``Y`` are either indexable along their first axis (tensors, arrays,
    plain sequences) or dicts of such arrays. For dicts, each item is a pair of
    dicts holding the ``idx``-th entry of every value.
    """

    def __init__(self, X, Y):
        super().__init__()
        self.X = X
        self.Y = Y
        if isinstance(X, dict):
            # Dict inputs: the sample count is taken from the "pd" entry.
            self.nex = X["pd"].shape[0]
        elif hasattr(X, "shape"):
            self.nex = X.shape[0]
        else:
            # Plain sequences (e.g. lists) have no .shape.
            self.nex = len(X)

    def __len__(self):
        return self.nex

    def __getitem__(self, idx):
        if isinstance(self.X, dict):
            x = {k: v[idx] for k, v in self.X.items()}
            y = {k: v[idx] for k, v in self.Y.items()}
            return x, y
        return self.X[idx], self.Y[idx]


def update_lamda(train_loader, net, feature_Generator, data, lamg, lamh, args):
    """Increase the constraint penalty multipliers by the observed violations.

    Runs ``net`` over every batch of ``train_loader`` without gradients and
    averages the violations over all samples. Equality violations are
    ``|eq_resid|``; inequality violations are the positive part of
    ``ineq_resid``. It then returns::

        lamg + ldstepsize * mean(ineq violations)       # one scalar step for all
        lamh + ldstepsize * per-constraint eq violations

    ``net`` is left in train mode. Gradient tracking is disabled only inside
    this call and restored afterwards, even if an exception is raised.
    """
    eq_viols, ineq_viols = [], []

    net.eval()
    try:
        with torch.no_grad():
            for Xtrain, _ in train_loader:
                # NOTE(review): training feeds the net the initial asset prices,
                # but here the net sees Xtrain -- confirm this is intended.
                Xfeat = Xtrain
                if 'predopt' in args['probtype'] and 'acopf' not in args['probtype']:
                    Xfeat = featGen(Xtrain, feature_Generator).to(DEVICE)
                Yhat_train = net(Xfeat.double())
                eq_viols.append(data.eq_resid(Xtrain, Yhat_train).abs())
                ineq_viols.append(torch.clamp(data.ineq_resid(Xtrain, Yhat_train), min=0.))
    finally:
        net.train()

    eq_viols = torch.cat(eq_viols, dim=0).mean(dim=0)
    ineq_viols = torch.cat(ineq_viols, dim=0).mean(dim=0)

    # NOTE(review): lamg receives one scalar increment (mean over all inequality
    # constraints) whereas lamh is updated per constraint -- confirm intended.
    lamg = lamg + args['ldstepsize'] * ineq_viols.mean()
    lamh = lamh + args['ldstepsize'] * eq_viols

    return lamg, lamh


if __name__ == '__main__':
    # Run training only when executed as a script, not on import.
    main()
